{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 实验六  线性代数方程组的直接解法\n",
    "\n",
    "顺序高斯消去法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def gaussian_elimination_sequence(a, b):\n",
    "    \"\"\"\n",
    "    用「顺序高斯消去法」解线性方程 `ax = b`。\n",
    "    \n",
    "    Args:\n",
    "        a : np_array_like 系数矩阵（nxn）\n",
    "        b : np_array_like 右端常数（n）\n",
    "    \n",
    "    Returns:\n",
    "        x : np.array `ax=b` 的解（n）\n",
    "        \n",
    "    Raises:\n",
    "        Exception(\"求解失败\") if a[k][k] == 0\n",
    "    \"\"\"\n",
    "    A = np.c_[a, b]  # 增广矩阵\n",
    "    \n",
    "    n, c = A.shape\n",
    "    assert c == n + 1, f'bad shape: {A.shape}'\n",
    "    \n",
    "    x = np.zeros(n)\n",
    "\n",
    "    # 消元\n",
    "    for k in range(n-1):\n",
    "        if A[k][k] == 0:\n",
    "            raise Exception(\"求解失败\")\n",
    "\n",
    "        for i in range(k+1, n):\n",
    "            m = A[i][k] / A[k][k];\n",
    "            A[i][k] = 0;\n",
    "            for j in range(k+1, n+1):\n",
    "                A[i][j] -= A[k][j] * m\n",
    "\n",
    "    # 回代\n",
    "    x[n-1] = A[n-1][n] / A[n-1][n-1]\n",
    "    for k in range(n-2, -1, -1): # from n-2 (included) to 0 (included)\n",
    "        for j in range(k+1, n):\n",
    "            A[k][n] -= A[k][j] * x[j]\n",
    "        x[k] = A[k][n] / A[k][k]\n",
    "        \n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.39823377,  0.01379507,  0.33514424])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Augmented system [a | b]; the last column holds the right-hand side.\n",
    "q = np.array([\n",
    "    [ 0.101, 2.304, 3.555, 1.183],\n",
    "    [-1.347, 3.712, 4.623, 2.137],\n",
    "    [-2.835, 1.072, 5.643, 3.035],\n",
    "])\n",
    "a = q[:, :-1]\n",
    "b = q[:, -1]\n",
    "gaussian_elimination_sequence(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.39823377,  0.01379507,  0.33514424])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cross-check the result with NumPy's built-in solver\n",
    "np.linalg.solve(a, b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "列主元高斯消去算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian_elimination(a, b, sequence=False):\n",
    "    \"\"\"\n",
    "    用「列主元高斯消去法」解线性方程 `ax = b`。\n",
    "    \n",
    "    本函数可以通过「列主元高斯消元法」或「顺序高斯消元法」计算，\n",
    "    通过参数 sequence 控制，默认 sequence=False 使用「列主元高斯消元法」。\n",
    "    \n",
    "    Args:\n",
    "        a : np_array_like 系数矩阵（nxn）\n",
    "        b : np_array_like 右端常数（n）\n",
    "    \n",
    "    Returns:\n",
    "        x : np.array `ax=b` 的解（n）\n",
    "        \n",
    "    Raises:\n",
    "        Exception(\"求解失败\") if a[k][k] == 0\n",
    "    \"\"\"\n",
    "    A = np.c_[a, b]  # 增广矩阵\n",
    "    \n",
    "    n, c = A.shape\n",
    "    assert c == n + 1, f'bad shape: {A.shape}'\n",
    "    \n",
    "    x = np.zeros(n)\n",
    "    \n",
    "    # 消元\n",
    "    for k in range(0, n-1):\n",
    "        if not sequence:  # 列主元\n",
    "            i_max = k + np.argmax(np.abs(A[k:n, k]))\n",
    "\n",
    "            if A[i_max][k] == 0:\n",
    "                raise Exception(\"A 奇异，求解失败\")\n",
    "\n",
    "            A[[i_max, k]] = A[[k, i_max]]  # swap rows\n",
    "        \n",
    "        for i in range(k+1, n):\n",
    "            m = A[i][k] / A[k][k];\n",
    "            for j in range(k+1, n+1):\n",
    "                A[i][j] -= A[k][j] * m\n",
    "    \n",
    "    # 回代\n",
    "    if a[n-1][n-1] == 0:\n",
    "        raise Exception(\"矩阵奇异，求解失败\")\n",
    "    \n",
    "    x[n-1] = A[n-1][n] / A[n-1][n-1]\n",
    "    for i in range(n-2, -1, -1): # from n-2 (included) to 0 (included)\n",
    "        for j in range(i+1, n):\n",
    "            A[i][n] -= A[i][j] * x[j]\n",
    "        x[i] = A[i][n] / A[i][i]\n",
    "        \n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "列主元高斯消元:\t [-0.39823377  0.01379507  0.33514424]\n",
      "顺序高斯消元:\t [-0.39823377  0.01379507  0.33514424]\n"
     ]
    }
   ],
   "source": [
    "# Same augmented system [a | b] as above, solved with both strategies.\n",
    "q = np.array([\n",
    "    [ 0.101, 2.304, 3.555, 1.183],\n",
    "    [-1.347, 3.712, 4.623, 2.137],\n",
    "    [-2.835, 1.072, 5.643, 3.035],\n",
    "])\n",
    "a = q[:, :-1]\n",
    "b = q[:, -1]\n",
    "print('列主元高斯消元:\\t', gaussian_elimination(a, b))\n",
    "print('顺序高斯消元:\\t', gaussian_elimination(a, b, sequence=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LU 分解"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lu(a, sequence=False, swap_times: list = {}):\n",
    "    \"\"\"\n",
    "    「高斯消去法」的 LU 分解.\n",
    "    \n",
    "    本函数可以计算「列主元高斯消元法」、「顺序高斯消元法」的 LU 分解，\n",
    "    通过参数 sequence 控制，默认 sequence=False 使用「列主元高斯消元法」。\n",
    "    \n",
    "    Args:\n",
    "        a: np_array_like 方阵 (nxn)\n",
    "        sequence: bool, True 则使用顺序高斯消去法，False 为列主元的高斯消去法\n",
    "            default: sequence=False\n",
    "        swap_times: 这是一个**输出**用的变量，只有传入 dict 变量时才有效。\n",
    "            若使用「列主元高斯消元法」（sequence=False）\n",
    "            则，置 swap_times['swap_times'] = 行交换次数。\n",
    "            这个值正常的输出中不需要，但在一些问题，比如，\n",
    "            利用 LU 分解求行列式时，得到 swap_times 会很有帮助。\n",
    "    \n",
    "    Returns:\n",
    "        (l, u, p): result\n",
    "        \n",
    "        l: np.array, Lower triangle result (nxn)\n",
    "        u: np.array, Upper triangle result (nxn)\n",
    "        p: np.array, Permutation: 交换后的行顺序 (n)\n",
    "            p = None if sequence=True\n",
    "        \n",
    "    Raises:\n",
    "        Exception: 存在为零的主元素\n",
    "    \"\"\"\n",
    "    a = np.array(a, dtype=np.float)  # copy\n",
    "    \n",
    "    assert a.shape[0] == a.shape[1]\n",
    "    n = a.shape[0]\n",
    "    \n",
    "    \n",
    "    if not sequence:\n",
    "        # p 记录行交换的过程，使用「列主元高斯消元法」才使用，否则为 None\n",
    "        p = np.array([k for k in range(n)])\n",
    "        # swap_times:  行交换次数\n",
    "        if isinstance(swap_times, dict):\n",
    "            swap_times['swap_times'] = 0\n",
    "    else:\n",
    "        p = None\n",
    "        \n",
    "    \n",
    "    for k in range(n-1):\n",
    "        if not sequence:\n",
    "            i_max = k + np.argmax(np.abs(a[k:n, k]))\n",
    "\n",
    "            if i_max != k:\n",
    "                a[[i_max, k]] = a[[k, i_max]]  # swap rows\n",
    "                p[[i_max, k]] = p[[k, i_max]]  # record\n",
    "                swap_times['swap_times'] += 1\n",
    "        \n",
    "        if a[k][k] == 0:\n",
    "            raise Exception(\"存在为零的主元素\")\n",
    "            \n",
    "        for i in range(k+1, n):\n",
    "            a[i][k] /= a[k][k]  # L @ 严格下三角\n",
    "            for j in range(k+1, n):\n",
    "                a[i][j] -= a[i][k] * a[k][j]  # U @ 上三角\n",
    "    \n",
    "    # print(a, p)\n",
    "    \n",
    "    # Uncommit the following lines to get a Permutation Matrix\n",
    "    # Only for sequence=False\n",
    "    # pm = np.zeros_like(a)\n",
    "    # for i, v in enumerate(p):\n",
    "    #     pm[i, v] = 1\n",
    "    #     print(pm)\n",
    "    \n",
    "    return np.tril(a, k=-1) + np.identity(a.shape[0]), np.triu(a), p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[1.  , 0.  , 0.  ],\n",
       "        [1.  , 1.  , 0.  ],\n",
       "        [0.25, 0.5 , 1.  ]]),\n",
       " array([[4. , 4. , 2. ],\n",
       "        [0. , 2. , 2. ],\n",
       "        [0. , 0. , 0.5]]),\n",
       " array([1, 2, 0]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LU decomposition with partial (column) pivoting\n",
    "lu([[1,2,2], [4,4,2], [4,6,4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[1. , 0. , 0. ],\n",
       "        [4. , 1. , 0. ],\n",
       "        [4. , 0.5, 1. ]]),\n",
       " array([[ 1.,  2.,  2.],\n",
       "        [ 0., -4., -6.],\n",
       "        [ 0.,  0., -1.]]),\n",
       " None)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LU decomposition by sequential elimination (no row exchanges)\n",
    "lu([[1,2,2], [4,4,2], [4,6,4]], sequence=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_lu(b, l, u, p=None):\n",
    "    \"\"\"用 lu(a) 得到的 `pa=lu` 分解的结果求解原方程组 `ax=b` 的解 x。\n",
    "    \n",
    "    若 p 不为 None 则使用「列主元高斯消元」，p 为 None表示使用「顺序高斯消元」。\n",
    "        \n",
    "        # `@` means matrix multiplication, refer: https://docs.python.org/reference/expressions.html#binary-arithmetic-operations\n",
    "        b = p @ b if p != None\n",
    "        l @ y = b\n",
    "        u @ x = y\n",
    "    \n",
    "    Args:\n",
    "        b: np_array_like, 原方程组的右端常数（n）\n",
    "        l: np_array_like, Lower triangle of lu_seq(a)\n",
    "        u: np_array_like, Upper triangle of lu_seq(a)\n",
    "        p: np_array_like, LU分解中交换后的行顺序\n",
    "            default p=None: 未做行交换，即使用顺序高斯消去法\n",
    "        \n",
    "        使用列主元高斯消元法时，l, u, p 使用 lu(a) 得到的结果即可：\n",
    "            solve_lu(b, *lu(a))\n",
    "        或者使用顺序高斯消元：\n",
    "            solve_lu(b, *lu(a, sequence=True))  # p=None\n",
    "        \n",
    "    Returns:\n",
    "        x : np.array `ax=b` 的解（n）\n",
    "    \"\"\"\n",
    "    assert np.shape(l) == np.shape(u)\n",
    "    assert np.shape(l)[0] == np.shape(b)[0]\n",
    "    \n",
    "    n = np.shape(l)[0]\n",
    "    \n",
    "    # do swap\n",
    "    if p is not None:\n",
    "        b = [b[v] for v in p]\n",
    "\n",
    "    # L * y = b\n",
    "    y = np.zeros(n, dtype=np.float)\n",
    "    y[0] = b[0]\n",
    "    for i in range(1, n):\n",
    "        bi = b[i]\n",
    "        for j in range(0, i):\n",
    "            bi -= y[j] * l[i][j]\n",
    "        y[i] = bi / l[i][i]\n",
    "    # print(y)\n",
    "\n",
    "    # U * x = y\n",
    "    x = np.zeros(n, dtype=np.float)\n",
    "    x[n-1] = y[n-1] / u[n-1][n-1]\n",
    "    for i in range(n-2, -1, -1): # from n-2 (included) to 0 (included)\n",
    "        yi = y[i]\n",
    "        for j in range(i+1, n):\n",
    "            yi -= x[j] * u[i][j]\n",
    "        x[i] = yi / u[i][i]\n",
    "    # print(x)\n",
    "    \n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.39823377  0.01379507  0.33514424]\n",
      "[-0.39823377  0.01379507  0.33514424]\n"
     ]
    }
   ],
   "source": [
    "# Same augmented system [a | b], now solved through the LU factors.\n",
    "q = np.array([\n",
    "    [ 0.101, 2.304, 3.555, 1.183],\n",
    "    [-1.347, 3.712, 4.623, 2.137],\n",
    "    [-2.835, 1.072, 5.643, 3.035],\n",
    "])\n",
    "a = q[:, :-1]\n",
    "b = q[:, -1]\n",
    "\n",
    "# Partial (column) pivoting\n",
    "x = solve_lu(b, *lu(a))\n",
    "print(x)\n",
    "\n",
    "# Sequential elimination\n",
    "x = solve_lu(b, *lu(a, sequence=True))\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LU 分解求行列式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def det(a):\n",
    "    \"\"\"Determinant of a square matrix.\n",
    "\n",
    "    Factorises `a` with lu() using partial pivoting and returns the product\n",
    "    of the diagonal of u, negated when the number of row exchanges is odd.\n",
    "\n",
    "    Args:\n",
    "        a: array_like, square matrix whose determinant is wanted\n",
    "\n",
    "    Returns:\n",
    "        d: float, determinant of a\n",
    "\n",
    "    Reference:\n",
    "        https://blog.csdn.net/nstarLDS/article/details/106074256\n",
    "    \"\"\"\n",
    "    counter = {}\n",
    "    _, upper, _ = lu(a, sequence=False, swap_times=counter)\n",
    "    # Each row exchange flips the sign of the determinant.\n",
    "    odd_swaps = counter['swap_times'] % 2 == 1\n",
    "    sign = -1 if odd_swaps else 1\n",
    "    pivot_product = np.prod(np.diag(upper))\n",
    "    return pivot_product * sign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21.20912390399999\n",
      "21.209123904000002\n"
     ]
    }
   ],
   "source": [
    "# Compare NumPy's determinant (first line) with the LU-based det() (second line)\n",
    "a = [[ 0.101, 2.304, 3.555],\n",
    "     [-1.347, 3.712, 4.623],\n",
    "     [-2.835, 1.072, 5.643]]\n",
    "print(np.linalg.det(a))\n",
    "print(det(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-2.0000000000000004\n",
      "-2.0\n"
     ]
    }
   ],
   "source": [
    "# Same comparison on a 2x2 integer matrix\n",
    "a = [[1,2], [3,4]]\n",
    "print(np.linalg.det(a))\n",
    "print(det(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "np.det: -0.01943799105427653\n",
      "pd dia: -0.01943799105427655\n",
      "my det: -0.01943799105427655\n",
      "{'swap_times': 2}\n",
      "----\n",
      "np.det: 0.0476193547152672\n",
      "pd dia: 0.047619354715267216\n",
      "my det: 0.047619354715267216\n",
      "{'swap_times': 2}\n",
      "----\n",
      "np.det: 0.0017777819144817564\n",
      "pd dia: -0.001777781914481734\n",
      "my det: 0.001777781914481734\n",
      "{'swap_times': 1}\n",
      "----\n",
      "np.det: -0.058467358804227446\n",
      "pd dia: 0.058467358804227446\n",
      "my det: -0.058467358804227446\n",
      "{'swap_times': 1}\n",
      "----\n",
      "np.det: -0.0026315646257671272\n",
      "pd dia: -0.002631564625767127\n",
      "my det: -0.002631564625767127\n",
      "{'swap_times': 2}\n",
      "----\n"
     ]
    }
   ],
   "source": [
    "def test_det(dim):\n",
    "    \"\"\"Print determinant figures for one random dim x dim matrix.\n",
    "\n",
    "    Shows, in order: np.linalg.det, the raw product of the diagonal of u\n",
    "    (no sign correction), det(), and the swap counter filled in by lu().\n",
    "    \"\"\"\n",
    "    matrix = np.random.random((dim, dim))\n",
    "    print('np.det:', np.linalg.det(matrix))\n",
    "    swaps = {}\n",
    "    _, upper, _ = lu(matrix, sequence=False, swap_times=swaps)\n",
    "    diag_product = np.prod(np.diag(upper))\n",
    "    print('pd dia:', diag_product)\n",
    "    print('my det:', det(matrix))\n",
    "    print(swaps)\n",
    "# Five independent 3x3 trials, separated by a rule\n",
    "for i in range(5):\n",
    "    test_det(3)\n",
    "    print('----')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LU 分解求逆"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inv(a):\n",
    "    \"\"\"Inverse of a square matrix.\n",
    "\n",
    "    Factorises `a` once with lu() (partial pivoting) and then solves\n",
    "    `a @ x = e_i` for every column e_i of the identity; the solutions are\n",
    "    the columns of the inverse.\n",
    "\n",
    "    Args:\n",
    "        a: array_like, square matrix to invert\n",
    "\n",
    "    Returns:\n",
    "        X: np.ndarray, inverse of a (same shape as a)\n",
    "\n",
    "    Reference:\n",
    "        https://en.wikipedia.org/wiki/LU_decomposition#Inverting_a_matrix\n",
    "    \"\"\"\n",
    "    l, u, p = lu(a)\n",
    "\n",
    "    # The builtin float replaces np.float, which was removed in NumPy 1.24.\n",
    "    X = np.zeros_like(a, dtype=float)\n",
    "    B = np.identity(np.shape(a)[0])\n",
    "\n",
    "    for i, b in enumerate(B.T):  # iterate over the columns of the identity\n",
    "        X[:, i] = solve_lu(b, l, u, p)\n",
    "\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0., 0.],\n",
       "       [0., 1., 0.],\n",
       "       [0., 0., 1.]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on the 3x3 identity matrix\n",
    "inv(np.identity(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-2.   1. ]\n",
      " [ 1.5 -0.5]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-2. ,  1. ],\n",
       "       [ 1.5, -0.5]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NumPy's inverse is printed first; the LU-based inv() is the cell output\n",
    "a = [[1,2], [3,4]]\n",
    "print(np.linalg.inv(a))\n",
    "inv(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# Compare inv() with np.linalg.inv on random matrices of every size from 0 to 99\n",
    "# and count the cases in which some entry differs by 1e-8 or more.\n",
    "\n",
    "err = 0\n",
    "\n",
    "for i in range(100):\n",
    "    a = np.random.random((i, i))\n",
    "\n",
    "    inv_np = np.linalg.inv(a)\n",
    "    inv_my = inv(a)\n",
    "\n",
    "    # abs() is essential: with a signed difference, arbitrarily large\n",
    "    # negative errors would satisfy `< 1e-8` and go unnoticed.\n",
    "    if not np.all(np.abs(inv_my - inv_np) < 1e-8):\n",
    "        err += 1\n",
    "\n",
    "print(err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试运行速度\n",
    "\n",
    "import time\n",
    "\n",
    "def time_cost(func):\n",
    "    \"\"\"print time cost of a function func\"\"\"\n",
    "    def wrapper(*args, **kwargs):\n",
    "        s = time.time()\n",
    "        ret = func(*args, **kwargs)\n",
    "        e = time.time() - s\n",
    "        print(f'{func.__name__} executed in {e: .3f} s')\n",
    "        return ret\n",
    "    return wrapper\n",
    "\n",
    "@time_cost\n",
    "def np_inv(a):\n",
    "    \"\"\"Invert `a` with np.linalg.inv; the time_cost decorator prints the elapsed time.\"\"\"\n",
    "    return np.linalg.inv(a)\n",
    "\n",
    "@time_cost\n",
    "def my_inv(a):\n",
    "    \"\"\"Invert `a` with the notebook's inv(); the time_cost decorator prints the elapsed time.\"\"\"\n",
    "    return inv(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "np_inv executed in  0.005 s\n",
      "my_inv executed in  3.866 s\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Time both implementations on one random 150x150 matrix, then compare them.\n",
    "a = np.random.random((150, 150))\n",
    "\n",
    "inv_np = np_inv(a)\n",
    "inv_my = my_inv(a)\n",
    "\n",
    "# abs() is essential: a signed difference would let negative errors pass.\n",
    "print(np.all(np.abs(inv_my - inv_np) < 1e-8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test system A x = b; note the leading entry A[0][0] = 0.3e-15\n",
    "A = [[0.3e-15, 59.14, 3,1],\n",
    "     [5.29, -6.13, -1, 2],\n",
    "     [11.2, 9, 5, 2], \n",
    "     [1, 2, 1, 1]]\n",
    "b = [59.17, 46.78, 1, 2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参考的正确解："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  3.84604049,   1.60956057, -15.47712511,  10.41196349])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference solution from NumPy's solver\n",
    "np.linalg.solve(A, b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "顺序高斯消去："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/site-packages/ipykernel_launcher.py:36: RuntimeWarning: invalid value encountered in double_scalars\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([nan, nan, nan, nan])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same system, eliminated in natural row order (no row exchanges)\n",
    "gaussian_elimination(A, b, sequence=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "列主元高斯消元:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  3.84604049,   1.60956057, -15.47712511,  10.41196349])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same system with partial (column) pivoting\n",
    "gaussian_elimination(A, b, sequence=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LU 分解：\n",
    "\n",
    "顺序："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "ename": "Exception",
     "evalue": "存在为零的主元素",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mException\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-29-a2c3103a4e3e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msequence\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'l={l}\\nu={u}\\np={p}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolve_lu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'x={x}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-d83f8c6bb23c>\u001b[0m in \u001b[0;36mlu\u001b[0;34m(a, sequence, swap_times)\u001b[0m\n\u001b[1;32m     53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"存在为零的主元素\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mException\u001b[0m: 存在为零的主元素"
     ]
    }
   ],
   "source": [
    "# Sequential LU is expected to fail on this system. The error is caught and\n",
    "# shown so that Restart & Run All can continue with the pivoted version below.\n",
    "try:\n",
    "    l, u, p = lu(A, sequence=True)\n",
    "    print(f'l={l}\\nu={u}\\np={p}')\n",
    "    x = solve_lu(b, l, u, p)\n",
    "    print(f'x={x}')\n",
    "except Exception as exc:\n",
    "    print(f'{type(exc).__name__}: {exc}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "列主元："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "l=[[ 1.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.67857143e-17  1.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 4.72321429e-01 -1.75530823e-01  1.00000000e+00  0.00000000e+00]\n",
      " [ 8.92857143e-02  2.02304459e-02 -1.73854511e-01  1.00000000e+00]]\n",
      "u=[[11.2         9.          5.          2.        ]\n",
      " [ 0.         59.14        3.          1.        ]\n",
      " [ 0.          0.         -2.83501467  1.23088797]\n",
      " [ 0.          0.          0.          1.01519355]]\n",
      "p=[2 0 1 3]\n",
      "x=[  3.84604049   1.60956057 -15.47712511  10.41196349]\n"
     ]
    }
   ],
   "source": [
    "# LU with partial pivoting (the default), then solve with the factors\n",
    "l, u, p = lu(A); print(f'l={l}\\nu={u}\\np={p}')\n",
    "x = solve_lu(b, l, u, p); print(f'x={x}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# "
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
